# Alex F. Gazmararian
# agazmararian@gmail.com
# Simulate data under assumptions about treatment effect size

source("analysis/visibility/analysis_config.R")
source(here("R", "load_functions.R"))
source(here("R", "visibility", "replication_mode.R"))

library(foreach)
library(doParallel)
library(parallel)

# Initialize replication mode
# The returned value is compared against "anonymized" further down to decide
# whether simulations are run or cached results are loaded.
# NOTE(review): init_replication_mode() comes from the sourced
# replication_mode.R; its side effects are not visible here.
REPLICATION_MODE <- init_replication_mode()
message("=== POWER ANALYSIS ===")
message("Replication mode: ", REPLICATION_MODE)

# Power analysis cache file (used in anonymized mode)
# Written after a full simulation run and read back in anonymized mode.
POWER_CACHE_FILE <- here("data", "cache", "power_analysis_cache.rds")

# Analysis data set; calc_power() reads this global `g` directly.
g <- readRDS(here("data", "output", "visibility_analysis.rds"))

# Configuration----
# Seeds the random number generator of this R session.
set.seed(10)

# Define variables for power analysis
# NOTE(review): treat.vars is not defined in this file; presumably it is set
# by the sourced analysis_config.R -- confirm.
var.list <- treat.vars

# Simulation-based power analysis for one quintile treatment variable.
#
# For every candidate effect size on the grid seq(0.01, 0.2, 0.005) the
# function simulates `n_sims` binary outcome vectors, re-estimates
# `y ~ treat + covariates | state` with Conley standard errors (400 cutoff on
# lat_zip / lon_zip), and records how often each coefficient is positive and
# significant at the 5% level.
#
# The simulated success probability is the observed "Q5" (control) mean plus
# the effect size scaled by a decay weight: 1 for Q1, `decay` for Q2,
# `decay^2` for Q3 and `decay^4` for Q4.
#
# Arguments:
#   var         Name of the treatment column in the global data frame `g`,
#               with levels "Q1".."Q5" ("Q5" is the control group).
#   outcome     Name of the binary outcome column in `g`; used for the control
#               mean and for dropping rows with a missing outcome.
#   n_cores     Number of PSOCK workers. NULL means all but one core, at most 8.
#   n_sims      Number of simulated data sets per effect size.
#   decay       Spatial decay factor applied to the effect for Q2-Q4.
#   covar.list  Character vector of covariate column names (required).
#
# Returns:
#   A data.frame with columns `term`, `pwr`, `mde` and `model` (= `var`), one
#   row per coefficient and effect size. It has zero rows when nothing could be
#   estimated.
#
# Both the parallel (%dopar%) and the sequential (%do%) path run the same loop
# body, so results no longer depend on whether a cluster could be started.
#
# NOTE(review): PSOCK workers are separate R processes, so a set.seed() call in
# the main session does not make the parallel draws reproducible. Consider
# per-task seeds or a reproducible-RNG backend if exact replication is needed.
calc_power <- function(var, outcome = "credit_biden_bin", n_cores = NULL,
                       n_sims = 500, decay = .75, covar.list = NULL) {

  # Check that covar.list is provided
  if (is.null(covar.list)) {
    stop("covar.list must be provided as an argument")
  }

  # Result returned whenever no power estimate can be produced
  empty_result <- data.frame(
    term = character(0),
    pwr = numeric(0),
    mde = numeric(0),
    model = character(0),
    stringsAsFactors = FALSE
  )

  # Parallel processing setup; detectCores() may return NA, hence na.rm
  if (is.null(n_cores)) {
    n_cores <- min(max(1, detectCores() - 1, na.rm = TRUE), 8)
  }

  message(paste0("Using ", n_cores, " cores for parallel processing"))
  message(paste0("Running ", n_sims, " simulations per MDE value"))

  # Drop any backend left registered by an earlier call
  if (getDoParWorkers() > 1) {
    registerDoSEQ()
    message("Cleaned up existing parallel backend")
  }

  # Register cleanup before the cluster exists, so a cluster that was created
  # but failed to register is still stopped on exit.
  cl <- NULL
  on.exit({
    if (!is.null(cl)) {
      tryCatch({
        stopCluster(cl)
        message("Cluster stopped successfully")
      }, error = function(e) {
        message("Warning during cleanup: ", conditionMessage(e))
      })
    }
    registerDoSEQ()
  }, add = TRUE)

  # `cl` is assigned in this function's frame because tryCatch() evaluates its
  # expression in the calling environment.
  parallel_ready <- tryCatch({
    cl <- makeCluster(n_cores, type = "PSOCK", outfile = "")
    registerDoParallel(cl)
    message("Successfully created fresh cluster with ", n_cores, " workers")
    TRUE
  }, error = function(e) {
    message("Failed to create parallel cluster: ", conditionMessage(e))
    message("Falling back to sequential processing")
    registerDoSEQ()
    FALSE
  })

  # Candidate effect sizes
  mde <- seq(0.01, .2, .005)

  # Assuming a diminishing effect of proximity: spatial decay.
  # NOTE(review): the Q4 weight is (decay^2)^2 = decay^4, not decay^3. This is
  # kept as in the original specification -- confirm it is intended.
  decay_weights <- c(1, decay, decay^2, (decay^2)^2)

  # Control group mean
  message("Calculating control group mean")
  y_mean <- g %>%
    filter(.data[[var]] == "Q5" & !is.na(.data[[outcome]])) %>%
    summarize(mean = mean(.data[[outcome]], na.rm = TRUE)) %>%
    pull(mean)

  message(sprintf("Control group mean for %s: %f", outcome, y_mean))

  if (!is.finite(y_mean)) {
    warning(
      "No usable control-group (Q5) observations for '", var,
      "'; returning an empty power table",
      call. = FALSE
    )
    return(empty_result)
  }

  # Simulation data: treatment, fixed effect, coordinates and covariates
  message("Preparing simulation data")
  d.sim <- g %>%
    filter(!is.na(.data[[outcome]]) & !is.na(.data[[var]])) %>%
    dplyr::select(
      treat = all_of(var), state, response_id, lat_zip, lon_zip,
      all_of(covar.list)
    )

  n_obs <- nrow(d.sim)

  # Row positions of each quintile, computed once instead of in every draw
  quintile_rows <- lapply(
    paste0("Q", 1:5),
    function(q) which(d.sim$treat == q)
  )

  # Only the formula text is built here. The formula object is created inside
  # the loop body so that this function's environment is not shipped to the
  # workers along with it.
  fml_text <- paste0(
    "y ~ treat + ", paste(covar.list, collapse = " + "), " | state"
  )

  use_parallel <- parallel_ready && getDoParWorkers() > 1
  if (use_parallel) {
    message(
      "Running power analysis in parallel with ", getDoParWorkers(), " workers"
    )
  } else {
    message("Running power analysis sequentially")
  }

  # Same loop body for both execution modes; only the operator differs.
  # Variables used in the body are local to this function, so foreach exports
  # them to the workers automatically. Names assigned inside the body are
  # distinct from the locals above because %do% evaluates it in this frame.
  `%run%` <- if (use_parallel) `%dopar%` else `%do%`

  power.out <- tryCatch({
    foreach(
      j = seq_along(mde),
      .packages = c("dplyr", "fixest", "broom"),
      .verbose = FALSE
    ) %run% {
      effect <- mde[j]

      # Success probabilities for Q1..Q4 (clamped to [0, 1]) and Q5 (control)
      probs <- c(pmin(pmax(y_mean + effect * decay_weights, 0), 1), y_mean)

      fml <- as.formula(fml_text)
      sim_data <- d.sim
      fits <- vector("list", n_sims)
      n_failed <- 0L
      first_error <- NULL

      for (i in seq_len(n_sims)) {
        # Rows outside Q1..Q5 keep an outcome of 0
        y_sim <- numeric(n_obs)
        for (k in seq_along(quintile_rows)) {
          rows <- quintile_rows[[k]]
          y_sim[rows] <- rbinom(length(rows), 1, probs[k])
        }
        sim_data$y <- y_sim

        fit <- tryCatch(
          broom::tidy(suppressWarnings(suppressMessages(feols(
            fml = fml,
            data = sim_data,
            vcov = vcov_conley(cutoff = 400, lat = "lat_zip", lon = "lon_zip")
          )))),
          error = function(e) e
        )

        if (inherits(fit, "error")) {
          # Failed fits are skipped, but reported once per effect size below
          n_failed <- n_failed + 1L
          if (is.null(first_error)) {
            first_error <- conditionMessage(fit)
          }
        } else {
          fits[[i]] <- fit
        }
      }

      if (n_failed > 0L) {
        message(sprintf(
          "MDE %.3f: %d of %d model fits failed (first error: %s)",
          effect, n_failed, n_sims, first_error
        ))
      }

      # Share of draws with a positive, significant estimate per coefficient
      tidy_all <- dplyr::bind_rows(fits)
      if (nrow(tidy_all) > 0) {
        tidy_all %>%
          mutate(hit = ifelse(p.value < 0.05 & estimate > 0, 1, 0)) %>%
          group_by(term) %>%
          summarize(pwr = mean(hit, na.rm = TRUE), .groups = "drop") %>%
          mutate(mde = effect)
      } else {
        NULL
      }
    }
  }, error = function(e) {
    message("Error in foreach loop: ", conditionMessage(e))
    message("Returning empty result")
    list()
  })

  # Keep only per-MDE tables (drops NULLs from effect sizes with no fits)
  valid <- Filter(is.data.frame, power.out)
  if (length(valid) == 0) {
    return(empty_result)
  }

  out <- as.data.frame(do.call(rbind, valid))
  rownames(out) <- NULL
  out$model <- rep(var, nrow(out))
  out
}

# Run calc_power() for every treatment variable in `var_list`, one after the
# other, resetting the foreach backend between variables.
#
# Arguments other than `var_list` are passed through to calc_power().
# Returns an unnamed list with one calc_power() result per variable, in the
# order of `var_list`.
run_power_analysis <- function(var_list, outcome = "credit_biden_bin", n_cores = NULL, n_sims = 500, decay = .75, covar.list = NULL) {
  # covar.list is mandatory
  if (is.null(covar.list)) {
    stop("covar.list must be provided as an argument")
  }

  # Start from a sequential backend if a parallel one is still registered
  if (getDoParWorkers() > 1) {
    registerDoSEQ()
    gc(verbose = FALSE)
  }

  n_vars <- length(var_list)

  lapply(seq_along(var_list), function(idx) {
    message(sprintf("Processing variable %d of %d: %s", idx, n_vars, var_list[idx]))

    # From the second variable on: reset the backend, collect garbage and
    # pause briefly before starting the next run
    if (idx > 1) {
      if (getDoParWorkers() > 1) {
        registerDoSEQ()
      }
      gc(verbose = FALSE)
      Sys.sleep(0.2)
    }

    power_tbl <- calc_power(
      var_list[idx],
      outcome = outcome,
      n_cores = n_cores,
      n_sims = n_sims,
      decay = decay,
      covar.list = covar.list
    )

    # Collect garbage after each variable
    gc(verbose = FALSE)

    power_tbl
  })
}

# Create power plot
#
# Plots simulated power against effect size for the "treatQ1" coefficient and
# labels the smallest effect size that reaches at least 80% power.
#
# Arguments:
#   df.in       Power table with columns `term`, `pwr`, `mde` and `model`
#               (as returned by calc_power()).
#   plot.title  Title of the plot.
#
# Returns a ggplot object.
create_power_plot <- function(df.in, plot.title) {
  pwr.q1 <- df.in %>%
    filter(term == "treatQ1")

  # Smallest effect size with at least 80% power. Ordering by `mde` rather
  # than by `pwr` matters because a simulated power curve need not be
  # monotone, and several effect sizes can share the same power (e.g. 1).
  mde_at_80 <- pwr.q1 %>%
    filter(pwr >= .8) %>%
    group_by(model) %>%
    slice_min(mde, n = 1, with_ties = FALSE) %>%
    ungroup()

  # Power curve with a reference line at 80%
  plot.out <- pwr.q1 %>%
    ggplot(aes(x = mde, y = pwr)) +
    geom_line(linewidth = 1) +
    geom_hline(yintercept = .8, color = "blue") +
    theme_classic(base_size = 10) +
    labs(
      title = plot.title,
      x = "Minimum detectable effect (MDE) at 5% significance level",
      y = "Power"
    ) +
    scale_y_continuous(
      labels = scales::percent_format(accuracy = 1),
      expand = c(0, 0),
      limits = c(0, 1.05)
    ) +
    scale_x_continuous(expand = c(0, 0), limits = c(-.01, .21)) +
    theme(
      axis.ticks = element_blank()
    )

  # Add MDE label only if some effect size reaches 80% power
  if (nrow(mde_at_80) > 0) {
    mde_label <- paste0("MDE: ", signif(mde_at_80$mde[1], 2))
    plot.out <- plot.out +
      annotate(
        "text", x = .02, y = .75,
        label = mde_label,
        color = "blue"
      )
  }

  plot.out
}

# Run power analysis----
# Each variable is processed sequentially, but within each variable,
# the MDE iterations are parallelized for maximum efficiency
# Define the treatment variables to analyze for power analysis from configuration

# Panel labels derived from the first two treatment labels
re_label <- tools::toTitleCase(gsub("_", " ", treat.labels[1]))
mfg_label <- tools::toTitleCase(gsub("_", " ", treat.labels[2]))

# Cache entry name -> outcome column analysed for that entry
power_outcomes <- c(
  credit = "credit_biden_bin",   # Credit attribution
  recog = "greenproj_bin",       # Recognition
  benefit = "greenbenefit_bin"   # Benefits
)

# In anonymized mode, load cached results instead of running simulations
# (simulations require lat/lon for Conley SEs)
if (REPLICATION_MODE == "anonymized") {
  message("ANONYMIZED MODE: Loading cached power analysis results")

  if (!file.exists(POWER_CACHE_FILE)) {
    stop("Anonymized mode requires cached power analysis results. File not found: ", POWER_CACHE_FILE)
  }

  power_cache <- readRDS(POWER_CACHE_FILE)

  # Fail here with a clear message rather than later with a subscript error
  # when the cache is stale or incomplete
  missing_entries <- setdiff(names(power_outcomes), names(power_cache))
  if (length(missing_entries) > 0) {
    stop(
      "Cached power analysis results are incomplete. Missing entries: ",
      paste(missing_entries, collapse = ", "),
      " in ", POWER_CACHE_FILE
    )
  }

} else {
  message("FULL MODE: Running power simulations")

  # One simulation run per outcome, in the order credit, recog, benefit.
  # lapply() keeps the names, so the list has the layout stored in the cache.
  power_cache <- lapply(power_outcomes, function(outcome_name) {
    run_power_analysis(
      var.list,
      outcome = outcome_name,
      covar.list = covar.list,
      n_sims = 1000
    )
  })

  # Cache results for future anonymized runs
  dir.create(dirname(POWER_CACHE_FILE), recursive = TRUE, showWarnings = FALSE)
  saveRDS(power_cache, POWER_CACHE_FILE)
  message("Cached power analysis results to: ", POWER_CACHE_FILE)
}

# Per-outcome results used by the plotting code (available in both modes)
pwr.credit <- power_cache$credit
pwr.recog <- power_cache$recog
pwr.benefit <- power_cache$benefit

# Create plots (works in both modes)
# Every figure has two panels: element 1 of a result list is drawn with the
# first treatment label, element 2 with the second.
re_panel_title <- paste(re_label, "proximity")
mfg_panel_title <- paste(mfg_label, "proximity")

## Credit attribution figure (fig_S2)
p.re.credit <- create_power_plot(pwr.credit[[1]], re_panel_title)
p.mfg.credit <- create_power_plot(pwr.credit[[2]], mfg_panel_title)

plot.credit <- (p.re.credit + p.mfg.credit) +
  plot_layout(ncol = 2) +
  plot_annotation(tag_levels = "A")

save_pnas_pdf(
  plot.credit,
  here("output", "pnas", "figures", "fig_S2_power_analysis_credit.pdf"),
  width = "double",
  height = 6
)
message(glue::glue("Saved power analysis plot for credit attribution to {here('output', 'pnas', 'figures')}"))

## Recognition figure (fig_S3)
p.re.recog <- create_power_plot(pwr.recog[[1]], re_panel_title)
p.mfg.recog <- create_power_plot(pwr.recog[[2]], mfg_panel_title)

plot.recog <- (p.re.recog + p.mfg.recog) +
  plot_layout(ncol = 2) +
  plot_annotation(tag_levels = "A")

save_pnas_pdf(
  plot.recog,
  here("output", "pnas", "figures", "fig_S3_power_analysis_recog.pdf"),
  width = "double",
  height = 6
)
message(glue::glue("Saved power analysis plot for recognition to {here('output', 'pnas', 'figures')}"))

## Benefits figure (fig_S4)
p.re.benefit <- create_power_plot(pwr.benefit[[1]], re_panel_title)
p.mfg.benefit <- create_power_plot(pwr.benefit[[2]], mfg_panel_title)

plot.benefit <- (p.re.benefit + p.mfg.benefit) +
  plot_layout(ncol = 2) +
  plot_annotation(tag_levels = "A")

save_pnas_pdf(
  plot.benefit,
  here("output", "pnas", "figures", "fig_S4_power_analysis_benefit.pdf"),
  width = "double",
  height = 6
)
message(glue::glue("Saved power analysis plot for benefits to {here('output', 'pnas', 'figures')}"))

# Final cleanup----
# Fall back to the sequential backend if a parallel one is still registered
if (getDoParWorkers() > 1) {
  registerDoSEQ()
  message("Final parallel cleanup completed")
}

# Collect garbage before finishing
gc(verbose = FALSE)

message("Power analysis complete")